import torch
from torch import nn
from model.towers import TwoTowerModel
from serving.faiss_indexer import FaissIndexer
from config import Config

if __name__ == "__main__":
    # Offline entry point: restore the trained two-tower model from disk and
    # build a Faiss index using its item tower.
    config = Config()

    model = TwoTowerModel(config)
    # map_location="cpu" lets a checkpoint that was saved from GPU tensors be
    # loaded on a CPU-only host (and avoids a needless GPU allocation
    # otherwise); load_state_dict then copies the tensors into the model's
    # existing parameters, so the model's own device placement is unchanged.
    state_dict = torch.load(config.TOWERS_MODEL_SAVE_PATH, map_location="cpu")
    model.load_state_dict(state_dict)

    item_tower = model.item_tower
    # Switch modules such as dropout / batch-norm to inference behavior.
    item_tower.eval()

    # Build the Faiss index. Indexing is inference-only, so run it under
    # no_grad to avoid recording an autograd graph if build_index runs the
    # item tower forward (presumably it does — confirm in FaissIndexer).
    indexer = FaissIndexer(config)
    with torch.no_grad():
        indexer.build_index(item_tower)